{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "aJ_pmgxvGur9"
   },
   "source": [
    "# Assignment 1 - Autoregressive models with Transformers\n",
    "## Deep Generative Modeling teaching material"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mEneMITS2agU"
   },
   "source": [
    "#### Instructions on how to use this notebook:\n",
    "\n",
    "This notebook is hosted on ``Google Colab``. To be able to work on it, you have to create your own copy. Go to *File* and select *Save a copy in Drive*.\n",
    "\n",
    "You can also avoid using ``Colab`` entirely, and download the notebook to run it on your own machine. If you choose this, go to *File* and select *Download .ipynb*.\n",
    "\n",
    "The advantage of using **Colab** is that you can use a GPU. You can complete this assignment with a CPU, but it will take a bit longer. Furthermore, we encourage you to train using the GPU not only for faster training, but also to get experience with this setting. This includes moving models and tensors to the GPU and back. This experience is very valuable because for various models and large datasets (like large CNNs for ImageNet, or Transformer models trained on Wikipedia), training on GPU is the only feasible way.\n",
    "\n",
    "The default ``Colab`` runtime does not have a GPU. To change this, go to *Runtime - Change runtime type*, and select *GPU* as the hardware accelerator. The GPU that you get changes according to what resources are available at the time, and its memory can go from a 5GB, to around 18GB if you are lucky. If you are curious, you can run the following in a code cell to check:\n",
    "\n",
    "```sh\n",
    "!nvidia-smi\n",
    "```\n",
    "\n",
    "Note that despite the name, ``Google Colab`` does  not support collaborative work without issues. When two or more people edit the notebook concurrently, only one version will be saved. You can choose to do group programming with one person sharing the screen with the others, or make multiple copies of the notebook to work concurrently.\n",
    "\n",
    "**Submission:** Please bring your (partial) solution to instruction sessions. Then you can discuss it with intructors and your colleagues."
   ]
  },
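  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of what moving tensors to the GPU and back looks like in PyTorch (the tensor `x` and its values are illustrative assumptions, not part of the assignment), the pattern below falls back to the CPU when no GPU is available:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Pick the GPU when one is visible to PyTorch, otherwise stay on the CPU.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "# Tensors are created on the CPU by default; .to(device) returns a copy on the target device.\n",
    "x = torch.arange(6, dtype=torch.float32).reshape(2, 3)\n",
    "x_dev = x.to(device)\n",
    "\n",
    "# Computation happens on whatever device the operands live on.\n",
    "y_dev = (x_dev * 2.0).sum(dim=1)\n",
    "\n",
    "# Move the result back to the CPU before converting it to NumPy.\n",
    "y = y_dev.cpu().numpy()\n",
    "print(device, y)\n",
    "```\n",
    "\n",
    "The same `.to(device)` call works for an `nn.Module`, in which case all of its parameters are moved."
   ]
  },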
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "lBgoJIpdLI2Y"
   },
   "outputs": [],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tsdc7fDp40rQ"
   },
   "source": [
    "## Introduction\n",
    "\n",
    "In this assignment, we are going to implement an autoregressive model (ARM). An AMR is a likelihood-based deep generative model that utilizes the product rule and generates new object one-by-one. Transformers are current state-of-the-art architectures used for Large Language Models (LLMs). Specifically, generative LLMs are parameterized by so called decoder-transformers. The model used in this assignemnt is based on the architecture of so called Generative Pretrained Transformers (GPTs):\n",
    "- [Radford, A., Narasimhan, K., Salimans, T. and Sutskever, I., 2018. Improving language understanding by generative pre-training.](https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf)\n",
    "\n",
    "You can read more about ARMs in Chapter 2 of the following book:\n",
    "- [Tomczak, J.M., \"Deep Generative Modeling\", Springer, 2022](https://link.springer.com/book/10.1007/978-3-030-93158-2)\n",
    "\n",
    "You can read more about transformers in Chapter 12 of the following book:\n",
    "- [Prince, S.J.D., \"Understanding Deep Learning\", MIT Press, 2023](https://udlbook.github.io/udlbook/)\n",
    "\n",
    "In particular, the goals of this assignment are the following:\n",
    "\n",
    "- Understand how transformer-based ARMs are formulated.\n",
    "- Implement components of transformer-based ARMs using PyTorch.\n",
    "- Train and evaluate a transformer-based ARM for text data.\n",
    "\n",
    "This notebook is essential for preparing a report. Moreover, please remember to submit the final notebook together with the report (PDF)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RvsuVNczG6pP"
   },
   "source": [
    "### Theory behind ARMs\n",
    "\n",
    "Let us consider a high-dimensional random variable $\\mathbf{x} \\in \\mathcal{X}^{T}$ where $\\mathcal{X} = \\{0,1,\\dots , L-1\\}$ or $\\mathcal{X} = \\mathbb{R}$. Our goal is to model $p(\\mathbf{x})$. We can apply the product rule to express this distribution as follows:\n",
    "$$\n",
    "p(\\mathbf{x}) = p(x_1) \\prod_{t=2}^{T} p(x_{t}|\\mathbf{x}_{<t}) ,\n",
    "$$\n",
    "where $\\mathbf{x}_{<t} = [x_1, x_2, \\ldots , x_{t-1}]^{\\top}$. For instance, for $\\mathbf{x} = [x_1, x_2, x_{3}]^{\\top}$, we have $p(\\mathbf{x}) = p(x_1) p(x_{2}|x_{1}) p(x_{3} | x_{1}, x_{2})$.\n",
    "\n",
    "The generative procedure is straightforward: We start with $x_1 \\sim p(x_1)$, and then we proceed with $x_t \\sim p(x_{t}|\\mathbf{x}_{<t})$ by plugging in all previously sampled variables $\\mathbf{x}_{<t}$. We can think of this procedure as a for-loop.\n",
    "\n",
    "Now, the main goal is how to parameterize conditional distributions $p(x_{t}|\\mathbf{x}_{<t})$. We can accomplish that by using neural networks, in particular, transformers. In this assignment, we focus on <i>decoder transformers</i> that utilize causal multi-head self-attention."
   ]
  },
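  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The for-loop view of sampling described above can be sketched with a toy model. The conditional below (`toy_conditional`, a hand-written rule over $L=3$ values) is a hypothetical stand-in for a neural network; only the loop structure and the product rule mirror the text.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import random\n",
    "\n",
    "L_VALUES = 3  # size of the toy alphabet {0, 1, 2}\n",
    "T = 5         # sequence length\n",
    "\n",
    "def toy_conditional(prefix):\n",
    "    # Hypothetical p(x_t | x_<t): favour repeating the previous value.\n",
    "    if not prefix:\n",
    "        return [1.0 / L_VALUES] * L_VALUES  # p(x_1) is uniform\n",
    "    probs = [0.2] * L_VALUES\n",
    "    probs[prefix[-1]] = 0.6\n",
    "    return probs\n",
    "\n",
    "def sample_sequence(rng):\n",
    "    x = []\n",
    "    for _ in range(T):\n",
    "        probs = toy_conditional(x)  # condition on everything sampled so far\n",
    "        x.append(rng.choices(range(L_VALUES), weights=probs)[0])\n",
    "    return x\n",
    "\n",
    "def log_likelihood(x):\n",
    "    # Product rule in log-space: sum over t of log p(x_t | x_<t)\n",
    "    return sum(math.log(toy_conditional(x[:t])[x[t]]) for t in range(len(x)))\n",
    "\n",
    "rng = random.Random(0)\n",
    "x = sample_sequence(rng)\n",
    "print(x, log_likelihood(x))\n",
    "```\n",
    "\n",
    "In the assignment, the role of `toy_conditional` is played by the decoder transformer, which outputs a categorical distribution over token values for every position."
   ]
  },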
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sIaNyIwxSfN_"
   },
   "source": [
    "### Note\n",
    "\n",
    "In this assignment, we build a simple LLM model. For this purpose, we use a dataset consisting of $\\sim 8.5$k newspaper headlines, and each headline contain at most 150 letters (tokens). You are provided with a tokenizer for turning characters into a sequence of integers and padding, and text processing functions (e.g., removing special characters). Your model will be trained with 1.3M tokens per iteration, and will consist of few millions to over dozen millions of weights.\n",
    "\n",
    "These numbers do not necessarilly impress anyone in the LLM community. However, please be aware that such datasets and models are not small and could be treated as a small-sized LLM-based problems. As you will notice in the end, we can still observe similar phenomena like hallucinations and the power of scaling up."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "suzhlbWqxtD9"
   },
   "source": [
    "## IMPORTS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BjxkigYLxpB7"
   },
   "outputs": [],
   "source": [
    "# DO NOT REMOVE!\n",
    "import os\n",
    "\n",
    "import pickle\n",
    "\n",
    "import spacy\n",
    "import nltk\n",
    "from nltk.corpus import stopwords\n",
    "from nltk.stem import WordNetLemmatizer\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "!pip install datasets\n",
    "from datasets import load_dataset\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "!pip install pytorch_model_summary\n",
    "from pytorch_model_summary import summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Cm23hRm6CqGh"
   },
   "outputs": [],
   "source": [
    "# DO NOT REMOVE OR MODIFY\n",
    "# Check if GPU is available and determine the device\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "else:\n",
    "    device = 'cpu'\n",
    "\n",
    "print(f'The available device is {device}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "81CxONpmMulC"
   },
   "outputs": [],
   "source": [
    "# DO NOT REMOVE! (unless you work locally)\n",
    "# mount drive: WE NEED IT FOR SAVING IMAGES! NECESSARY FOR GOOGLE COLAB!\n",
    "from google.colab import drive\n",
    "drive.mount('/content/gdrive')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "OoPb92zNM4UY"
   },
   "outputs": [],
   "source": [
    "# DO NOT REMOVE! (unless you work locally)\n",
    "# PLEASE CHANGE IT TO YOUR OWN GOOGLE DRIVE OR YOUR LOCAL DIR!\n",
    "results_model_dir = '/content/gdrive/My Drive/Colab/Results/'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "I3zs31tOyCmq"
   },
   "source": [
    "## Auxiliary classes and functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DF0agzL7tDHK"
   },
   "source": [
    "Let us define some useful classes:\n",
    "1. DataProcessor: \"cleaning\" texts.\n",
    "2. Tokenizer: transforming characters to integers and padding."
   ]
  },
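  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the encoding concrete, here is a minimal standalone sketch of character-level tokenization with a leading `cls` token and trailing `pad` tokens. The helper names `toy_encode` and `toy_decode` are illustrative assumptions; the `Tokenizer` class defined in this notebook is the one used in the assignment.\n",
    "\n",
    "```python\n",
    "import string\n",
    "\n",
    "# Vocabulary: 0 = pad, 1..26 = letters, 27 = space, 28 = cls (29 values in total).\n",
    "vocab = {'pad': 0}\n",
    "for idx, letter in enumerate(string.ascii_lowercase, start=1):\n",
    "    vocab[letter] = idx\n",
    "vocab[' '] = 27\n",
    "vocab['cls'] = 28\n",
    "inverse_vocab = {v: k for k, v in vocab.items()}\n",
    "\n",
    "def toy_encode(text, max_length):\n",
    "    # Start with cls, then the characters, then pad up to max_length + 1 tokens.\n",
    "    tokens = [vocab['cls']] + [vocab[ch] for ch in text[:max_length]]\n",
    "    return tokens + [vocab['pad']] * (max_length + 1 - len(tokens))\n",
    "\n",
    "def toy_decode(tokens):\n",
    "    # Stop at the first pad token and skip cls.\n",
    "    chars = []\n",
    "    for tok in tokens:\n",
    "        if tok == vocab['pad']:\n",
    "            break\n",
    "        if inverse_vocab[tok] != 'cls':\n",
    "            chars.append(inverse_vocab[tok])\n",
    "    return ''.join(chars)\n",
    "\n",
    "tokens = toy_encode('hi there', max_length=10)\n",
    "print(tokens)\n",
    "print(toy_decode(tokens))\n",
    "```\n",
    "\n",
    "Every encoded sequence has `max_length + 1` entries because of the extra `cls` token at the start."
   ]
  },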
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LIBNVRNJtHSd"
   },
   "outputs": [],
   "source": [
    "# DO NOT REMOVE OR MODIFY\n",
    "class DataProcessor(object):\n",
    "    def __init__(self, ):\n",
    "        super().__init__()\n",
    "        nlp = spacy.load(\"en_core_web_sm\")\n",
    "        nltk.download('omw-1.4')\n",
    "        nltk.download(\"punkt\")\n",
    "        nltk.download(\"wordnet\")\n",
    "        nltk.download(\"stopwords\")\n",
    "\n",
    "    @staticmethod\n",
    "    def preprocess_text(text):\n",
    "        # Tokenize, remove punctuation and lowercase\n",
    "        tokens = nltk.word_tokenize(text)\n",
    "        tokens = [word.lower() for word in tokens if word.isalpha()]\n",
    "\n",
    "        # Remove stopwords and lemmatize\n",
    "        stop_words = set(stopwords.words(\"english\"))\n",
    "        lemmatizer = WordNetLemmatizer()\n",
    "        processed_text = [\n",
    "            lemmatizer.lemmatize(word) for word in tokens if word not in stop_words\n",
    "        ]\n",
    "\n",
    "        return \" \".join(processed_text)\n",
    "\n",
    "    def process_batch(self, texts):\n",
    "        return [self.preprocess_text(d) for d in texts]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4NjoCwjV3TN7"
   },
   "outputs": [],
   "source": [
    "# DO NOT REMOVE OR MODIFY\n",
    "class Tokenizer(object):\n",
    "    def __init__(self, max_length=0):\n",
    "        super().__init__()\n",
    "\n",
    "        self.max_length = max_length\n",
    "\n",
    "        self.alphabet_letters = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']\n",
    "\n",
    "        self.alphabet = self.prepare_alphabet()\n",
    "        self.decoded_alphabet = self.prepare_decoded_alphabet()\n",
    "\n",
    "    def prepare_alphabet(self):\n",
    "        # PREPARE THE ALPHABET (CHAR->INT)\n",
    "        # as a dictionary\n",
    "        alphabet = {}\n",
    "        alphabet['pad'] = 0  # add 'pad'\n",
    "        count = 1\n",
    "\n",
    "        for letter in self.alphabet_letters:\n",
    "            alphabet[letter] = count\n",
    "            count += 1\n",
    "\n",
    "        # add ' ', 'cls' tokens\n",
    "        alphabet[' '] = count\n",
    "        alphabet['cls'] = count + 1\n",
    "\n",
    "        return alphabet\n",
    "\n",
    "    def prepare_decoded_alphabet(self):\n",
    "        # PREPARE DECODED ALPHABET (INT->CHAR)\n",
    "        decoded_alphabet_ints = [i for i in range(len(self.alphabet_letters))]\n",
    "\n",
    "        decoded_alphabet = {}\n",
    "        decoded_alphabet[0] = 'pad'\n",
    "\n",
    "        for i in decoded_alphabet_ints:\n",
    "            decoded_alphabet[i+1] = self.alphabet_letters[i]\n",
    "\n",
    "            decoded_alphabet[i+2] = ' '\n",
    "        decoded_alphabet[i+3] = 'cls'\n",
    "\n",
    "        return decoded_alphabet\n",
    "\n",
    "    def encode(self, texts):\n",
    "        N = len(texts)\n",
    "\n",
    "        if self.max_length == 0:\n",
    "            max_length = 0\n",
    "            for i in range(N):\n",
    "                len_i = len(texts[i])\n",
    "                if len_i > max_length:\n",
    "                    max_length = len_i\n",
    "        else:\n",
    "            max_length = self.max_length\n",
    "\n",
    "        tokens = np.zeros((N, max_length+1))\n",
    "\n",
    "        for i in range(N):\n",
    "            len_i = len(texts[i])\n",
    "            for j in range(-1, max_length):\n",
    "                if j == -1:\n",
    "                    tokens[i,j+1] = self.alphabet['cls']\n",
    "                elif j >= len_i:\n",
    "                    tokens[i,j+1] = self.alphabet['pad']\n",
    "                else:\n",
    "                    if texts[i][j] == 'é':\n",
    "                        tokens[i,j+1] = self.alphabet['e']\n",
    "                    elif texts[i][j] == 'í':\n",
    "                        tokens[i,j+1] = self.alphabet['e']\n",
    "                    elif texts[i][j] == 'á':\n",
    "                        tokens[i,j+1] = self.alphabet['a']\n",
    "                    elif texts[i][j] == 'ó':\n",
    "                        tokens[i,j+1] = self.alphabet['o']\n",
    "                    elif texts[i][j] == 'æ':\n",
    "                        tokens[i,j+1] = self.alphabet['a']\n",
    "                    elif texts[i][j] == 'ä':\n",
    "                        tokens[i,j+1] = self.alphabet['a']\n",
    "                    else:\n",
    "                        tokens[i,j+1] = self.alphabet[texts[i][j]]\n",
    "\n",
    "        return tokens\n",
    "\n",
    "    def decode(self, tokens):\n",
    "        texts = []\n",
    "\n",
    "        for i in range(len(tokens)):\n",
    "            tokens_i = tokens[i,:]\n",
    "            text_i = ''\n",
    "            for j in range(len(tokens_i)):\n",
    "                if tokens_i[j] == 0:\n",
    "                    break\n",
    "                else:\n",
    "                    if self.decoded_alphabet[tokens_i[j]] != 'cls':\n",
    "                        text_i += self.decoded_alphabet[tokens_i[j]]\n",
    "            texts.append(text_i)\n",
    "\n",
    "        return texts"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VhiWi-j3mELC"
   },
   "source": [
    "Some useful functions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MoB-RuczmGlT"
   },
   "outputs": [],
   "source": [
    "# DO NOT REMOVE OR MODIFY\n",
    "def save_texts(sampled_texts, name=''):\n",
    "    # open file in write mode\n",
    "    with open(results_dir + '/samples_' + name + '.txt', 'w') as fp:\n",
    "        for item in sampled_texts:\n",
    "            # write each item in a new line\n",
    "            fp.write(\"%s\\n\" % item)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "q60onqQR3TN8"
   },
   "source": [
    "# Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0K9XaWO_3TN8"
   },
   "outputs": [],
   "source": [
    "# DO NOT REMOVE OR MODIFY\n",
    "class Headers(Dataset):\n",
    "    \"\"\"A simple dataset based on headers. Source: https://huggingface.co/datasets/IlyaGusev/headline_cause\"\"\"\n",
    "\n",
    "    def __init__(self, dataprocessor, tokenizer, mode='train', num_training_data=None, transforms=None):\n",
    "        # LOAD DATA\n",
    "        dataset = load_dataset(\"IlyaGusev/headline_cause\", \"en_simple\")\n",
    "\n",
    "        # PREPARE DATA\n",
    "        if mode == 'train':\n",
    "            train_texts = dataprocessor.process_batch(dataset['train'][:]['left_title'] + dataset['train'][:]['right_title']) # list\n",
    "            if num_training_data is None:\n",
    "                self.data = torch.from_numpy(tokenizer.encode(train_texts)).long()\n",
    "            else:\n",
    "                self.data = torch.from_numpy(tokenizer.encode(train_texts))[:num_training_data].long()\n",
    "        elif mode == 'val':\n",
    "            validation_texts = dataprocessor.process_batch(dataset['validation'][:]['left_title'] + dataset['validation'][:]['right_title']) # list\n",
    "            self.data = torch.from_numpy(tokenizer.encode(validation_texts)).long()\n",
    "        else:\n",
    "            test_texts = dataprocessor.process_batch(dataset['test'][:]['left_title'] + dataset['test'][:]['right_title']) # list\n",
    "            self.data = torch.from_numpy(tokenizer.encode(test_texts)).long()\n",
    "\n",
    "        self.transforms = transforms\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        sample = self.data[idx]\n",
    "        if self.transforms:\n",
    "            sample = self.transforms(sample)\n",
    "        return sample"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Q2LLOs0kn7iw"
   },
   "source": [
    "## Implementing ARMs with Transformers\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7cXhOwKAzW6Z"
   },
   "source": [
    "### Loss Function (NLL)\n",
    "Our loss function is the negative log-likelihood for the categorical distribution (i.e., the cross-entropy loss).\n",
    "\n",
    "Please note how it is implemented and how tokens (T) are handled."
   ]
  },
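  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small sketch of how the targets are shifted for next-token prediction, and how `nn.NLLLoss` applied to log-probabilities relates to the cross-entropy. The random logits stand in for a model output and are purely illustrative.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "torch.manual_seed(0)\n",
    "B, T, V = 2, 5, 29  # batch size, tokens per sequence, number of token values\n",
    "\n",
    "x = torch.randint(0, V, (B, T))  # token ids\n",
    "logits = torch.randn(B, T, V)    # stand-in for the model output\n",
    "log_probs = F.log_softmax(logits, dim=2)\n",
    "\n",
    "# Position t predicts token t+1: drop the last prediction and the first target.\n",
    "pred = log_probs[:, :-1].reshape(B * (T - 1), V)\n",
    "target = x[:, 1:].reshape(B * (T - 1))\n",
    "\n",
    "nll_per_token = nn.NLLLoss(reduction='none')(pred, target)\n",
    "nll_per_sequence = nll_per_token.view(B, T - 1).sum(dim=1)\n",
    "\n",
    "# NLL on log-softmax outputs equals the cross-entropy computed from raw logits.\n",
    "ce = F.cross_entropy(logits[:, :-1].reshape(-1, V), target, reduction='none')\n",
    "print(nll_per_sequence, torch.allclose(nll_per_token, ce))\n",
    "```\n",
    "\n",
    "With `reduction='mean'`, `LossFun` averages the per-sequence sums over the batch, which corresponds to `nll_per_sequence.mean()` in this sketch."
   ]
  },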
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MrwQXSuEoFfH"
   },
   "outputs": [],
   "source": [
    "# DO NOT REMOVE OR MODIFY\n",
    "class LossFun(nn.Module):\n",
    "    def __init__(self,):\n",
    "        super().__init__()\n",
    "\n",
    "        self.loss = nn.NLLLoss(reduction='none')\n",
    "\n",
    "    def forward(self, y_model, y_true, reduction='sum'):\n",
    "        # y_model: B(atch) x T(okens) x V(alues)\n",
    "        # y_true: B x T\n",
    "        B, T, V = y_model.size()\n",
    "\n",
    "        y_model = y_model.view(B * T, V)\n",
    "        y_true = y_true.view(B * T,)\n",
    "\n",
    "        loss_matrix = self.loss(y_model, y_true) # B*T\n",
    "\n",
    "        if reduction == 'sum':\n",
    "            return torch.sum(loss_matrix)\n",
    "        elif reduction == 'mean':\n",
    "            loss_matrix = loss_matrix.view(B, T)\n",
    "            return torch.mean(torch.sum(loss_matrix, 1))\n",
    "        else:\n",
    "            raise ValueError('Reduction could be either `sum` or `mean`.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nhNTy5mn0XDT"
   },
   "source": [
    "### Transformer block\n",
    "\n",
    "Transformers consist of transformer block. In the cell below, please define a transformer block."
   ]
  },
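  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a hint for the causal part of self-attention, the following sketch shows one common way to build a causal mask with `torch.tril` and apply it before the softmax. It covers a single head with random scores, is not a complete transformer block, and the variable names are assumptions.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "torch.manual_seed(0)\n",
    "T = 4  # number of tokens\n",
    "\n",
    "# Raw attention scores for one head: entry (t, s) says how much token t attends to token s.\n",
    "scores = torch.randn(T, T)\n",
    "\n",
    "# Lower-triangular mask: token t may only look at positions s <= t.\n",
    "mask = torch.tril(torch.ones(T, T)).bool()\n",
    "\n",
    "# Masked positions get -inf, so they receive zero weight after the softmax.\n",
    "masked_scores = scores.masked_fill(~mask, float('-inf'))\n",
    "weights = F.softmax(masked_scores, dim=-1)\n",
    "print(weights)\n",
    "```\n",
    "\n",
    "Each row of `weights` still sums to one, but all entries above the diagonal are zero, so no token can use information from later positions."
   ]
  },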
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9vTmKHwrpUVa"
   },
   "outputs": [],
   "source": [
    "# YOUR CODE GOES HERE\n",
    "# NOTE: The class must containt the following elements:\n",
    "# (i) components (nn.Module) of a transformer bloc\n",
    "# (ii) the forward function\n",
    "# Moreover, forward must return the processed input\n",
    "\n",
    "class TransformerBlock(nn.Module):\n",
    "    def __init__(self, num_emb, num_neurons, num_heads=4):\n",
    "        super().__init__()\n",
    "\n",
    "        # hyperparams\n",
    "        self.D = num_emb  # the embedding size\n",
    "        self.H = num_heads  # the number of heads\n",
    "        self.neurons = num_neurons  # the number of neurons in MLP\n",
    "\n",
    "        # components\n",
    "\n",
    "        #\n",
    "        # your code goes here\n",
    "        #\n",
    "        # please define all necessary components that constitute a transformer block\n",
    "        # NOTE: please use hyperparams defined above\n",
    "\n",
    "    def forward(self, x, causal=True):\n",
    "        #\n",
    "        # your code goes here\n",
    "        #\n",
    "        # please define the forward pass for a tranformer block"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CLIEwIiw00op"
   },
   "source": [
    "### ARM (Decoder-Transformer)\n",
    "\n",
    "Once we have a class for transformer blocks, we need to define a decoder-transformer that defines an auto-regressive model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xQIvee5Cp69V"
   },
   "outputs": [],
   "source": [
    "# DO NOT REMOVE OR MODIFY\n",
    "class DecoderTransformer(nn.Module):\n",
    "    def __init__(self, num_tokens, num_token_vals, num_emb, num_neurons, num_heads=2, dropout_prob=0.1, num_blocks=10, device='cpu'):\n",
    "        super().__init__()\n",
    "\n",
    "        # hyperparams\n",
    "        self.device = device\n",
    "        self.num_tokens = num_tokens\n",
    "        self.num_token_vals = num_token_vals\n",
    "        self.num_emb = num_emb\n",
    "        self.num_blocks = num_blocks\n",
    "\n",
    "        # embedding layer\n",
    "        self.embedding = torch.nn.Embedding(num_token_vals, num_emb)\n",
    "\n",
    "        # positional embedding\n",
    "        self.positional_embedding = nn.Embedding(num_tokens, num_emb)\n",
    "\n",
    "        # transformer blocks\n",
    "        self.transformer_blocks = nn.ModuleList()\n",
    "        for _ in range(num_blocks):\n",
    "            self.transformer_blocks.append(TransformerBlock(num_emb=num_emb, num_neurons=num_neurons, num_heads=num_heads))\n",
    "\n",
    "        # output layer (logits + softmax)\n",
    "        self.logits = nn.Sequential(nn.Linear(num_emb, num_token_vals))\n",
    "\n",
    "        # dropout layer\n",
    "        self.dropout = nn.Dropout(dropout_prob)\n",
    "\n",
    "        # loss function\n",
    "        self.loss_fun = LossFun()\n",
    "\n",
    "    def transformer_forward(self, x, causal=True, temperature=1.0):\n",
    "        # x: B(atch) x T(okens)\n",
    "        # embedding of tokens\n",
    "        x = self.embedding(x) # B x T x D\n",
    "        # embedding of positions\n",
    "        pos = torch.arange(0, x.shape[1], dtype=torch.long).unsqueeze(0).to(self.device)\n",
    "        pos_emb = self.positional_embedding(pos)\n",
    "        # dropout of embedding of inputs\n",
    "        x = self.dropout(x + pos_emb)\n",
    "\n",
    "        # transformer blocks\n",
    "        for i in range(self.num_blocks):\n",
    "            x = self.transformer_blocks[i](x)\n",
    "\n",
    "        # output logits\n",
    "        out = self.logits(x)\n",
    "\n",
    "        return F.log_softmax(out/temperature, 2)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def sample(self, batch_size=4, temperature=1.0):\n",
    "        x_seq = np.asarray([[self.num_token_vals - 1] for i in range(batch_size)])\n",
    "\n",
    "        # sample next tokens\n",
    "        for i in range(self.num_tokens-1):\n",
    "            xx = torch.tensor(x_seq, dtype=torch.long, device=self.device)\n",
    "            # process x and calculate log_softmax\n",
    "            x_log_probs = self.transformer_forward(xx, temperature=temperature)\n",
    "            # sample i-th tokens\n",
    "            x_i_sample = torch.multinomial(torch.exp(x_log_probs[:,i]), 1).to(self.device)\n",
    "            # update the batch with new samples\n",
    "            x_seq = np.concatenate((x_seq, x_i_sample.to('cpu').detach().numpy()), 1)\n",
    "\n",
    "        return x_seq\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def top1_rec(self, x, causal=True):\n",
    "        x_prob = torch.exp(self.transformer_forward(x, causal=True))[:,:-1,:].contiguous()\n",
    "        _, x_rec_max = torch.max(x_prob, dim=2)\n",
    "        return torch.sum(torch.mean((x_rec_max.float() == x[:,1:].float().to(device)).float(), 1).float())\n",
    "\n",
    "    def forward(self, x, causal=True, temperature=1.0, reduction='mean'):\n",
    "        # get log-probabilities\n",
    "        log_prob = self.transformer_forward(x, causal=causal, temperature=temperature)\n",
    "\n",
    "        return self.loss_fun(log_prob[:,:-1].contiguous(), x[:,1:].contiguous(), reduction=reduction)"
   ]
  },
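  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `temperature` argument in `transformer_forward` divides the logits before the softmax. The standalone sketch below (with made-up logits) illustrates the effect: temperatures below 1 sharpen the distribution, while temperatures above 1 flatten it.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def softmax_with_temperature(logits, temperature=1.0):\n",
    "    # Scale the logits, subtract the maximum for numerical stability, then normalise.\n",
    "    scaled = [l / temperature for l in logits]\n",
    "    m = max(scaled)\n",
    "    exps = [math.exp(s - m) for s in scaled]\n",
    "    total = sum(exps)\n",
    "    return [e / total for e in exps]\n",
    "\n",
    "logits = [2.0, 1.0, 0.0]  # illustrative values only\n",
    "for temp in (0.5, 1.0, 2.0):\n",
    "    probs = softmax_with_temperature(logits, temp)\n",
    "    print(temp, [round(p, 3) for p in probs])\n",
    "```\n",
    "\n",
    "When sampling texts, a lower temperature tends to give more conservative outputs and a higher temperature more diverse ones."
   ]
  },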
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hLhgze7DA4yx"
   },
   "source": [
    "### Evaluation and training functions\n",
    "\n",
    "**Please DO NOT remove or modify them.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "I9Dr3a6lqJ0W"
   },
   "outputs": [],
   "source": [
    "# DO NOT REMOVE OR MODIFY\n",
    "def evaluation(test_loader, name=None, model_best=None, epoch=None, device='cuda'):\n",
    "    # EVALUATION\n",
    "    if model_best is None:\n",
    "        # load best performing model\n",
    "        model_best = torch.load(name + '.model').to(device)\n",
    "\n",
    "    model_best.eval()\n",
    "    loss = 0.\n",
    "    rec = 1.\n",
    "    N = 0.\n",
    "    for indx_batch, test_batch in enumerate(test_loader):\n",
    "        loss_t = model_best.forward(test_batch.to(device), reduction='sum')\n",
    "        loss = loss + loss_t.item()\n",
    "\n",
    "        rec_t = model_best.top1_rec(test_batch.to(device))\n",
    "        rec = rec + rec_t.item()\n",
    "\n",
    "        N = N + test_batch.shape[0]\n",
    "    loss = loss / N\n",
    "    rec = rec / N\n",
    "\n",
    "    if epoch is None:\n",
    "        print(f'FINAL LOSS: nll={loss}, rec={rec}')\n",
    "    else:\n",
    "        print(f'Epoch: {epoch}, val nll={loss}, val rec={rec}')\n",
    "\n",
    "    return loss, rec\n",
    "\n",
    "def plot_curve(name, nll_val, ylabel='nll'):\n",
    "    plt.plot(np.arange(len(nll_val)), nll_val, linewidth='3')\n",
    "    plt.xlabel('epochs')\n",
    "    plt.ylabel(ylabel)\n",
    "    plt.savefig(name + '_' + ylabel + '_val_curve.pdf', bbox_inches='tight')\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9ABgMeG0qFAP"
   },
   "outputs": [],
   "source": [
    "# DO NOT REMOVE OR MODIFY\n",
    "def training(name, max_patience, num_epochs, model, optimizer, training_loader, val_loader, device='cuda'):\n",
    "    nll_val = []\n",
    "    rec_val = []\n",
    "    best_nll = 1000.\n",
    "    patience = 0\n",
    "\n",
    "    # Main loop\n",
    "    for e in range(num_epochs):\n",
    "        # TRAINING\n",
    "        model.train()\n",
    "        for indx_batch, batch in enumerate(training_loader):\n",
    "            loss = model.forward(batch.to(device))\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward(retain_graph=True)\n",
    "            optimizer.step()\n",
    "\n",
    "        # Validation\n",
    "        loss_val, r_val = evaluation(val_loader, model_best=model, epoch=e, device=device)\n",
    "        nll_val.append(loss_val)  # save for plotting\n",
    "        rec_val.append(r_val)\n",
    "\n",
    "        if e == 0:\n",
    "            print('saved!')\n",
    "            torch.save(model, name + '.model')\n",
    "            best_nll = loss_val\n",
    "\n",
    "            sampled_tokens = model.sample(batch_size=64, temperature=1.0)\n",
    "            sampled_texts = tokenizer.decode(sampled_tokens)\n",
    "            save_texts(sampled_texts, name='epoch_' + str(e))\n",
    "\n",
    "        else:\n",
    "            if loss_val < best_nll:\n",
    "                print('saved!')\n",
    "                torch.save(model, name + '.model')\n",
    "                best_nll = loss_val\n",
    "                patience = 0\n",
    "\n",
    "                sampled_tokens = model.sample(batch_size=64, temperature=1.0)\n",
    "                sampled_texts = tokenizer.decode(sampled_tokens)\n",
    "                save_texts(sampled_texts, name='epoch_' + str(e))\n",
    "            else:\n",
    "                patience = patience + 1\n",
    "\n",
    "        if patience > max_patience:\n",
    "            break\n",
    "\n",
    "    nll_val = np.asarray(nll_val)\n",
    "    rec_val = np.asarray(rec_val)\n",
    "\n",
    "    np.save(name + '_nll_val.npy', nll_val)\n",
    "    np.save(name + '_rec_val.npy', rec_val)\n",
    "\n",
    "    return nll_val, rec_val"
   ]
  },
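  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The early-stopping logic in `training` can be isolated as follows. This is a simplified sketch on a hypothetical list of validation losses (`val_losses`); it mirrors the patience counter but leaves out saving the model and sampling texts.\n",
    "\n",
    "```python\n",
    "def early_stopping_epoch(val_losses, max_patience):\n",
    "    # Returns (index of the last epoch that was run, best validation loss seen).\n",
    "    best = float('inf')\n",
    "    patience = 0\n",
    "    last_epoch = -1\n",
    "    for epoch, loss in enumerate(val_losses):\n",
    "        last_epoch = epoch\n",
    "        if loss < best:\n",
    "            best = loss      # improvement: remember it and reset the counter\n",
    "            patience = 0\n",
    "        else:\n",
    "            patience += 1    # no improvement in this epoch\n",
    "        if patience > max_patience:\n",
    "            break            # stop after more than max_patience bad epochs in a row\n",
    "    return last_epoch, best\n",
    "\n",
    "val_losses = [5.0, 4.0, 4.2, 4.1, 4.3, 3.9]  # made-up numbers\n",
    "print(early_stopping_epoch(val_losses, max_patience=2))\n",
    "```\n",
    "\n",
    "Note that with a small `max_patience` the loop can stop before a later improvement is reached, which is the usual trade-off between training time and the risk of stopping too early."
   ]
  },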
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kWr8N2u2qNTu"
   },
   "source": [
    "### Setup\n",
    "\n",
    "**NOTE: *Please comment your code! Especially if you introduce any new variables (e.g., hyperparameters).***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bFTE5jtYpxDV"
   },
   "outputs": [],
   "source": [
    "# DO NOT REMOVE OR MODIFY\n",
    "dataprocessor = DataProcessor()\n",
    "tokenizer = Tokenizer(max_length=149)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "UXHitzrYqNhY"
   },
   "outputs": [],
   "source": [
    "# PLEASE MODIFY ACCORDING TO THE REPORT REQUIREMENTS\n",
    "num_training_data = 1000  # None to take all training data\n",
    "\n",
    "# DO NOT REMOVE OR MODIFY THE REST OF THIS CELL\n",
    "#-dataset\n",
    "train_dataset = Headers(dataprocessor, tokenizer, num_training_data=num_training_data, mode=\"train\")\n",
    "validation_dataset = Headers(dataprocessor, tokenizer, mode=\"val\")\n",
    "test_dataset = Headers(dataprocessor, tokenizer, mode=\"test\")\n",
    "\n",
    "#-dataloaders\n",
    "BATCH_SIZE = 32\n",
    "\n",
    "training_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "val_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False)\n",
    "test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "OTBFl-bNb0s9"
   },
   "outputs": [],
   "source": [
    "# DO NOT REMOVE (but you can modify if necessary)\n",
    "#-creating a dir for saving results\n",
    "name = 'arm_transformer'  # NOTE: if you run multiple experiments, you would overwrite results. Please modify this part if necessary.\n",
    "results_dir = results_model_dir + name + '/'\n",
    "if not(os.path.exists(results_dir)):\n",
    "  os.mkdir(results_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kmKDXMI0B231"
   },
   "source": [
    "In the next cell, please initialize the model. Please remember about commenting your code!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "b73aaBDxqSYb"
   },
   "outputs": [],
   "source": [
    "# DO NOT REMOVE but PLEASE MODIFY WHENEVER YOU ARE ASKED FOR IT!\n",
    "# NOTE: in order to obtain required sizes of your models, you can play with\n",
    "#       various values of num_neurons, num_heads, num_blocks, num_emb\n",
    "num_tokens = 150 # do not modify!\n",
    "num_token_vals = 29  # do not modify!\n",
    "num_neurons = 16 # please modify it\n",
    "num_heads = 4 # please modify it\n",
    "num_blocks = 4 # please modify it\n",
    "num_emb = num_heads * 4  # please modify it but it must be a multiplication of num_heads\n",
    "causal=True # do not modify!\n",
    "\n",
    "lr = 1e-3 # learning rate; do not modify!\n",
    "num_epochs = 1000 # max. number of epochs; do not modify!\n",
    "max_patience = 10 # an early stopping is used, if training doesn't improve for longer than 20 epochs, it is stopped; do not modify!"
   ]
  },
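  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To check whether a configuration reaches a required model size, the trainable parameters can be counted directly. The sketch below uses a small stand-in network (an embedding followed by a linear layer, not the assignment model), and the helper name `count_parameters` is an assumption; the same helper can be applied to any `nn.Module`.\n",
    "\n",
    "```python\n",
    "import torch.nn as nn\n",
    "\n",
    "def count_parameters(module):\n",
    "    # Sum the number of elements of all trainable parameter tensors.\n",
    "    return sum(p.numel() for p in module.parameters() if p.requires_grad)\n",
    "\n",
    "num_token_vals, num_emb = 29, 16\n",
    "toy = nn.Sequential(\n",
    "    nn.Embedding(num_token_vals, num_emb),  # 29 * 16 weights\n",
    "    nn.Linear(num_emb, num_token_vals),     # 16 * 29 weights + 29 biases\n",
    ")\n",
    "print(count_parameters(toy))\n",
    "```\n",
    "\n",
    "Increasing `num_emb`, `num_neurons`, or `num_blocks` in the actual model grows this count, which is how the required model sizes can be reached."
   ]
  },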
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JQy_iNJs3TOD"
   },
   "outputs": [],
   "source": [
    "# DO NOT REMOVE OR MODIFY\n",
    "model = DecoderTransformer(num_tokens=num_tokens, num_token_vals=num_token_vals, num_emb=num_emb, num_neurons=num_neurons, num_heads=num_heads, num_blocks=num_blocks, device=device)\n",
    "model = model.to(device)\n",
    "# Print the summary (like in Keras)\n",
    "print(summary(model, torch.zeros(1, num_tokens, dtype=torch.long).to(device), show_input=False, show_hierarchical=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "iC8AkWt4CURT"
   },
   "source": [
    "Please initialize the optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "a3nTSDe7ql08"
   },
   "outputs": [],
   "source": [
    "# DO NOT REMOVE OR MODIFY\n",
    "optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad == True], lr=lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "P5GrzUcHFweG"
   },
   "source": [
    "### Training and final evaluation\n",
    "\n",
    "In the following two cells, we run the training and the final evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VD7WuY6bqnBK"
   },
   "outputs": [],
   "source": [
    "# DO NOT REMOVE OR MODIFY\n",
    "# Training procedure\n",
    "nll_val, rec_val = training(name=results_dir + name, max_patience=max_patience, num_epochs=num_epochs, model=model, optimizer=optimizer, training_loader=training_loader, val_loader=val_loader, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JAuMt9_wquOI"
   },
   "outputs": [],
   "source": [
    "# DO NOT REMOVE OR MODIFY\n",
    "# Final evaluation\n",
    "test_loss, test_rec = evaluation(name=results_dir + name, test_loader=, device=device)\n",
    "\n",
    "with open(results_dir + name + '_test_loss.txt', \"w\") as f:\n",
    "    f.write('Test NLL: ' + str(test_loss)+'\\n'+'Test REC: ' + str(test_rec))\n",
    "    f.close()\n",
    "\n",
    "plot_curve(results_dir + name, nll_val, ylabel='nll')\n",
    "plot_curve(results_dir + name, rec_val, ylabel='rec')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LcQ0GPZy3TOE"
   },
   "outputs": [],
   "source": [
    "# DO NOT REMOVE\n",
    "# Sample texts: load best model\n",
    "model_best = torch.load(results_dir + name + '.model')\n",
    "model_best = model_best.eval()\n",
    "\n",
    "# sample\n",
    "temperature = 1.0 # you can modify it\n",
    "num_samples = 64 # you can modify it\n",
    "\n",
    "sampled_tokens = model_best.sample(batch_size=num_samples, temperature=temperature)  # do not modify\n",
    "sampled_texts = tokenizer.decode(sampled_tokens)  # do not modify\n",
    "\n",
    "save_texts(sampled_texts, name='FINAL_' + str(temperature))"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": [
    {
     "file_id": "1C6CaZsSdXdc5fONOVrJy5iYKzr2DIg9j",
     "timestamp": 1607008727578
    }
   ],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
